#rm(list = ls())
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✔ ggplot2 3.3.6     ✔ purrr   0.3.4
## ✔ tibble  3.1.7     ✔ dplyr   1.0.9
## ✔ tidyr   1.2.0     ✔ stringr 1.4.0
## ✔ readr   2.1.2     ✔ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(future)
library(ggthemes)

In this notebook we will test the performance of varKode to distinguish species of Malpighiaceae and figure out the best parameters for training a dataset.

Kmer size and amount of data

To start, we produced images from different numbers of kmers. We can suppose that shorter kmers will offer lower resolution to resolve species, but they will also create smaller files that require less computation. Here we will test whether images based on longer kmers result in higher accuracy. As an example, here are images produced from 200Mb for the same sample, but different kmer sizes (5-9):

# Example images for the same sample at 200 Mb, one per kmer size (5-9).
knitr::include_graphics(paste0('images_',5:9,'/S_bannisterioides+S-91_00200000K.png'))

We also used different amounts of data to produce images, since we want to figure out the lowest amount needed to distinguish species. With less data, figures get more noisy since chance plays a bigger role in the observed kmer frequencies. This should be more severe for larger kmer sizes, since each kmer will be more unique in the genome.

For example, images for 5-mer for the same sample as above, for 500Kb and 200Mb:

# Same sample at 500 Kb vs 200 Mb. NOTE(review): the surrounding text says
# 5-mer but this reads from images_6 -- confirm which is intended.
knitr::include_graphics(paste0('images_6/S_bannisterioides+S-91_00',c('000500','200000'),'K.png'))

The same, but for 8-mers:

# The same 500 Kb vs 200 Mb comparison, for 8-mers.
knitr::include_graphics(paste0('images_8/S_bannisterioides+S-91_00',c('000500','200000'),'K.png'))

Now that we understand the differences between images, let’s examine the effect on accuracy. We previously trained CNN models to recognize images for each combination of kmer size and amount of data, with 10 replicates per combination. In each replicate, we kept 3 randomly chosen samples per species as a validation set and checked the accuracy of the trained model in guessing the species of these samples, for different amounts of data used for the validation sample. What we want is to find:

1 - The lowest kmer size to produce high accuracy

2 - The lowest amount of data needed

3 - Whether the amount of data used for training and for querying must be similar.

The results of these simulations were saved as a csv table, let’s load it (ignoring the first, index column):

# Load the kmer-size vs data-amount simulation results, dropping the first
# column (an unnamed row index, read in as `...1`).
df <- read_csv('kmerSize_VS_bp.csv')[-1]
## New names:
## Rows: 4500 Columns: 11
## ── Column specification
## ──────────────────────────────────────────────────────── Delimiter: "," chr
## (3): bp_training, samples_training, samples_valid dbl (8): ...1, kmer_size,
## replicate, bp_valid, n_samp_training, n_samp_valid...
## ℹ Use `spec()` to retrieve the full column specification for this data. ℹ
## Specify the column types or set `show_col_types = FALSE` to quiet this message.
## • `` -> `...1`
df

Now let’s make sure bp_training and bp_valid are treated as ordered factors for nice plotting:

# Build ordered factor levels for the data amount (converted bp -> Mb).
# Entries containing '|' encode training sets that mixed several amounts;
# they are excluded here and represented by a final 'all' level.
# (Fixed: the original str_detect call carried a trailing empty argument.)
not_all <- as.character(sort(as.integer(unique(df$bp_training[!str_detect(df$bp_training, '\\|')])) / 1e6))
ordered_levels <- c(not_all, 'all')

df <- df %>%
  # '|'-separated mixtures fail integer coercion and become NA here, so the
  # "NAs introduced by coercion" warning is expected; they are relabeled
  # 'all' in the next step.
  mutate(bp_training = as.character(as.integer(bp_training) / 1e6)) %>%
  mutate(bp_training = replace_na(bp_training, 'all')) %>%
  mutate(bp_training = factor(bp_training,
                              levels = ordered_levels,
                              ordered = TRUE),
         bp_valid = factor(as.character(as.integer(bp_valid) / 1e6),
                           levels = ordered_levels,
                           ordered = TRUE),
         # kmer_size becomes an ordered factor with numerically sorted levels.
         kmer_size = factor(as.character(kmer_size),
                            levels = as.character(sort(unique(kmer_size))),
                            ordered = TRUE))
## Warning in mask$eval_all_mutate(quo): NAs introduced by coercion
df

Now we can plot:

# Facet labeller: prefix each kmer size with a readable caption.
kmer_labeller <- as_labeller(function(value) {
  paste0('kmer length:', value)
})

# One jittered point per replicate/validation run, colored by validation
# accuracy, faceted by kmer size.
ggplot(df, aes(x = bp_training, y = bp_valid)) +
  geom_jitter(aes(color = valid_acc)) +
  scale_color_viridis_c('Validation\naccuracy', option = 'inferno', limits = c(0, 1)) +
  facet_grid(~kmer_size, labeller = kmer_labeller) +
  coord_equal() +
  xlab('Data in training images (Mb)') +
  ylab('Data in validation images (Mb)')

Now let’s make a version with accuracy averaged across replicates:

# Average validation accuracy over replicates, then draw one raster tile per
# (training amount, validation amount) combination, faceted by kmer size.
# `.groups = 'drop'` makes the ungrouping explicit (the original relied on
# the default and emitted a "grouped output" message).
p <- df %>%
  group_by(kmer_size, bp_training, bp_valid) %>%
  summarize(valid_acc = mean(valid_acc), .groups = 'drop') %>%
  ggplot(aes(x = bp_training, y = bp_valid, fill = valid_acc)) +
  geom_raster() +
  scale_fill_viridis_c('Average\nvalidation\naccuracy', option = 'magma',
                       limits = c(0, 1), labels = scales::percent) +
  facet_grid(~kmer_size, labeller = kmer_labeller) +
  coord_equal() +
  xlab('Data in training images (Mb)') +
  ylab('Data in validation images (Mb)') +
  theme_few(base_size = 6)
p

# Create the output directory only when missing (the unconditional
# dir.create warned "'paper_images' already exists" on reruns).
if (!dir.exists('paper_images')) dir.create('paper_images')
# NOTE(review): dpi = 2400 at 22 cm produces a very large raster -- confirm
# this resolution is really needed for the paper figure.
ggsave(filename = 'kmerlen_vs_accuracy.png', plot = p, device = 'png',
       path = 'paper_images', width = 22, height = 5, units = 'cm', dpi = 2400)
# Per (training amount, kmer size) summary of validation accuracy, used as
# reference lines in the histograms below.
# NOTE(review): despite the object name `means` and column name `Int`, the
# statistic computed here is a median.
means = df %>%
  filter(bp_training %in% c('0.5','1','200','all')) %>%
  #filter(bp_valid %in% c('50','100','200','all')) %>%
  filter(bp_valid %in% c('2','5','10','20','50','100','200')) %>%
  group_by(bp_training,kmer_size) %>%
  summarise(Int=median(valid_acc))
## `summarise()` has grouped output by 'bp_training'. You can override using the
## `.groups` argument.
# Histograms of per-replicate accuracy for selected training amounts, with
# the per-group median (`means$Int`, computed above) as a vertical line.
# The duplicated aes mapping in geom_histogram was removed: it repeated the
# x mapping already set in ggplot().
df %>%
  filter(bp_training %in% c('0.5','1','200','all')) %>%
  filter(bp_valid %in% c('2','5','10','20','50','100','200')) %>%
  ggplot(aes(x = valid_acc)) +
  geom_histogram() +
  facet_grid(kmer_size ~ bp_training) +
  geom_vline(data = means, aes(xintercept = Int))
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

So it seems that the smallest kmer sizes never result in very high accuracy, and the largest kmer sizes result in high accuracy for higher amounts of data, but lower accuracy for lower amounts. It seems that a kmer size of 7 is a good balance, and that training using images of different sizes helps in being more robust to the amount of data used to produce validation images.

As little as 1Mb produces moderately accurate results for kmer size 7 or below.

Can we quantify what is different about images produced with different data amounts?

# Collect every varKode image across kmer sizes 5-9.
# The pattern is anchored as '\\.png$': list.files() treats `pattern` as a
# regex, so the original '.png' matched any character before "png" and did
# not require it at the end of the filename.
images <- unlist(lapply(paste0('images_', 5:9), function(dir) {
  list.files(path = dir, pattern = '\\.png$', recursive = TRUE, full.names = TRUE)
}))

# Number of distinct canonical kmers of length k: (4^k + p)/2, where p is
# the count of self-reverse-complement kmers (4^(k/2) for even k, 0 for odd).
# Formula from:
# https://bioinfologics.github.io/post/2018/09/17/k-mer-counting-part-i-introduction/
nkmers <- function(k) {
  palindromic <- (1 - k %% 2) * 4^(k / 2)  # zero whenever k is odd
  (4^k + palindromic) / 2
}

# Summarize one varKode image: recover its metadata from the file path and
# measure how variable its pixel values are.
# Returns a one-row data.frame with kmer size (k), taxon, sample id, amount
# of data (Mbp), and the SD of the counts of each distinct pixel value.
get_sd <- function(path){
  # Metadata is encoded in the path: images_<k>/<taxon>+<sample>_<bp>K.png
  k <- as.integer(gsub('.+_([0-9])/.+','\\1', path))
  taxon <- gsub('.+/(.+)\\+.+','\\1', path)
  sample <- gsub('.+\\+(S-[0-9]+)_.+','\\1', path)
  Mbp <- as.integer(gsub('.+_([0-9]{8})K.+','\\1', path)) / 1000

  # Keep only the nkmers(k) largest pixel values -- presumably the cells
  # that encode kmer frequencies, with padding sorting first. TODO confirm.
  pixels <- sort(png::readPNG(path))
  pixels <- tail(pixels, nkmers(k))

  data.frame(k = k, taxon = taxon, sample = sample, Mbp = Mbp,
             sd_counts = sd(table(pixels)))
}

# Summarize every image in parallel using 4 background R sessions, then
# return to sequential execution.
# NOTE: this overwrites `df` (the accuracy table loaded above) with the
# per-image summaries produced by get_sd().
plan(multisession(workers = 4))
df = furrr::future_map_dfr(images,get_sd)
plan(sequential)


df
# SD of pixel-value counts vs amount of data, one line per sample, one
# free-scaled panel per kmer size; both axes on log10 scales.
ggplot(df, aes(x = Mbp, y = sd_counts, color = sample)) +
  geom_line() +
  scale_x_log10() +
  scale_y_log10() +
  scale_color_discrete(guide = 'none') +
  facet_wrap(as.factor(k)~., scales = 'free')

Training parameters

Now we will check the results of using different training parameters:

- model pretraining
- augmentation (CutMix or MixUp)
- label smoothing
- model architecture
- lighting transformations

Let’s read the data and prepare for plotting:

# Load the training-parameter sweep (dropping the unnamed index column) and
# derive readable labels for each parameter combination.
# NOTE(review): the ifelse() calls on label_smoothing / pretrained / trans
# rely on those columns being logical, as the read_csv spec below reports.
df = read_csv('training_params.csv')[-1] %>%
  mutate(bp_valid = factor(as.character(as.integer(bp_valid)/1e6), 
                           levels = sort(unique(bp_valid/1e6)), 
                           ordered = TRUE),
         # Which augmentation callback was used ('None' when neither matches).
         augmentation = ifelse(str_detect(callback,'CutMix'),'CutMix',
                               ifelse(str_detect(callback,'MixUp'),'MixUp',
                                      'None')
                               ),
         augmentation = factor(augmentation, levels = c('None','MixUp','CutMix'),ordered = F),
         # String versions of each setting, empty when the option is off, so
         # they can be pasted into a compact combination label below.
         aug = str_replace(augmentation,'None',''),
         lablsmth= ifelse(label_smoothing,
                                  'label Smoothing',
                                  ''),
         pretr = ifelse(pretrained,
                             'pretrained',
                             ''
                             ),
         transformations = ifelse(trans,
                             'with_transforms',
                             ''
                             ),
         # One label per parameter combination: drop empty fields (collapse
         # repeated commas, strip leading/trailing ones), call the all-off
         # combination 'None', and order levels by mean validation accuracy
         # so plots read worst -> best.
         parameters = paste(arch,pretr,lablsmth,aug,transformations,sep=',') %>%
           str_replace_all(',{2,}',',') %>%
           str_remove_all('^,|,$') %>%
           str_replace_all('^$','None') %>%
           fct_reorder(valid_acc, mean)
  )
## New names:
## Rows: 30240 Columns: 16
## ── Column specification
## ──────────────────────────────────────────────────────── Delimiter: "," chr
## (5): bp_training, samples_training, samples_valid, callback, arch dbl (8):
## ...1, kmer_size, replicate, bp_valid, n_samp_training, n_samp_valid... lgl (3):
## label_smoothing, pretrained, trans
## ℹ Use `spec()` to retrieve the full column specification for this data. ℹ
## Specify the column types or set `show_col_types = FALSE` to quiet this message.
## • `` -> `...1`
df 

Now we can plot the effect of parameters. There are clearly some models that do much better than others:

# One jittered point per replicate for every parameter combination (x-axis
# ordered by mean accuracy during data preparation); color = Mbp in the
# validation images.
df %>%
  ggplot(aes(x = parameters, y = valid_acc, color = bp_valid)) +
  geom_jitter(height = 0.005) +
  scale_color_viridis_d(option = 'turbo', begin = 0.1, end = 0.9) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

What are the 20 most accurate models?

# Same plot restricted to the 20 parameter combinations with the highest
# mean validation accuracy (i.e. the last 20 factor levels).
df %>%
  filter(parameters %in% tail(levels(df$parameters),20)) %>%
  ggplot(aes(x = parameters, y = valid_acc, color = bp_valid)) +
  geom_jitter(height = 0.005) +
  scale_color_viridis_d(option = 'turbo', begin = 0.1, end = 0.9) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

Let’s plot by architecture:

# Validation accuracy by model architecture, ordered by fct_reorder's
# default summary (the median). Crossbars mark the per-architecture mean.
p <- df %>%
  mutate(arch = fct_reorder(arch, valid_acc)) %>%
  ggplot(aes(x = arch, y = valid_acc, color = bp_valid)) +
  geom_jitter(height = 0.005, size = 0.1, alpha = 0.1, shape = 16) +
  stat_summary(fun = mean, geom = 'crossbar', size = 0.05, show.legend = FALSE) +
  scale_color_viridis_d(option = 'turbo', begin = 0.1, end = 0.9,
                        name = 'Mbp in validation\nimages',
                        guide = guide_legend(override.aes = list(alpha = 1))) +
  scale_y_continuous(labels = scales::percent, name = 'Validation Accuracy') +
  xlab('Model architecture') +
  theme_few(base_size = 6) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        legend.key.size = unit(0.2, "cm"))

p

ggsave(filename = 'architecture.png', plot = p, device = 'png',
       path = 'paper_images', width = 5, height = 5, units = 'cm', dpi = 2400)

Now by pretrained:

# Validation accuracy by pretraining status.
# NOTE(review): the x-axis labels are positional -- they assume fct_reorder
# puts 'pretrained' before '' (random init). Verify against the actual
# level order before publishing.
p <- df %>%
  mutate(pretr = fct_reorder(pretr, valid_acc)) %>%
  ggplot(aes(x = pretr, y = valid_acc, color = bp_valid)) +
  geom_jitter(height = 0.005, size = 0.05, alpha = 0.1, shape = 16) +
  stat_summary(fun = mean, geom = 'crossbar', size = 0.05) +
  scale_x_discrete(labels = c('pre-trained','random'), name = 'Model pretraining') +
  scale_color_viridis_d(option = 'turbo', begin = 0.1, end = 0.9,
                        name = 'Mbp in validation\nimages') +
  scale_y_continuous(labels = scales::percent, name = 'Validation Accuracy') +
  theme_few(base_size = 6) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        legend.position = 'none')

p

ggsave(filename = 'pretraining.png', plot = p, device = 'png',
       path = 'paper_images', width = 3, height = 5, units = 'cm', dpi = 2400)

Now by label smoothing:

# Validation accuracy by label smoothing.
# NOTE(review): 'No'/'Yes' labels are positional -- they assume the '' level
# (no smoothing) sorts first after fct_reorder. Verify the level order.
p <- df %>%
  mutate(lablsmth = fct_reorder(lablsmth, valid_acc)) %>%
  ggplot(aes(x = lablsmth, y = valid_acc, color = bp_valid)) +
  geom_jitter(height = 0.005, size = 0.05, alpha = 0.1, shape = 16) +
  stat_summary(fun = mean, geom = 'crossbar', size = 0.05) +
  scale_x_discrete(labels = c('No','Yes'), name = 'Label smoothing') +
  scale_color_viridis_d(option = 'turbo', begin = 0.1, end = 0.9,
                        name = 'Mbp in validation\nimages') +
  scale_y_continuous(labels = scales::percent, name = 'Validation Accuracy') +
  theme_few(base_size = 6) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        legend.position = 'none')

p

ggsave(filename = 'labelsmoothing.png', plot = p, device = 'png',
       path = 'paper_images', width = 3, height = 5, units = 'cm', dpi = 2400)

Now by CutMix/MixUp augmentations:

# Validation accuracy by CutMix/MixUp augmentation, levels ordered by
# accuracy; crossbars mark the per-group mean.
p <- df %>%
  mutate(augmentation = fct_reorder(augmentation, valid_acc)) %>%
  ggplot(aes(x = augmentation, y = valid_acc, color = bp_valid)) +
  geom_jitter(height = 0.005, size = 0.05, alpha = 0.1, shape = 16) +
  stat_summary(fun = mean, geom = 'crossbar', size = 0.05) +
  scale_x_discrete(name = 'Augmentation') +
  scale_color_viridis_d(option = 'turbo', begin = 0.1, end = 0.9,
                        name = 'Mbp in validation\nimages') +
  scale_y_continuous(labels = scales::percent, name = 'Validation Accuracy') +
  theme_few(base_size = 6) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        legend.position = 'none')

p

ggsave(filename = 'augmentations.png', plot = p, device = 'png',
       path = 'paper_images', width = 3, height = 5, units = 'cm', dpi = 2400)

Finally, by lighting transforms:

# Validation accuracy by lighting transforms.
# NOTE(review): 'No'/'Yes' labels are positional -- they assume the ''
# (no transforms) level sorts first after fct_reorder. Verify the order.
p <- df %>%
  mutate(transformations = fct_reorder(transformations, valid_acc)) %>%
  ggplot(aes(x = transformations, y = valid_acc, color = bp_valid)) +
  geom_jitter(height = 0.005, size = 0.05, alpha = 0.1, shape = 16) +
  stat_summary(fun = mean, geom = 'crossbar', size = 0.05) +
  scale_x_discrete(name = 'Lighting transforms', labels = c('No','Yes')) +
  scale_color_viridis_d(option = 'turbo', begin = 0.1, end = 0.9,
                        name = 'Mbp in validation\nimages') +
  scale_y_continuous(labels = scales::percent, name = 'Validation Accuracy') +
  theme_few(base_size = 6) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        legend.position = 'none')

p

ggsave(filename = 'lighting.png', plot = p, device = 'png',
       path = 'paper_images', width = 3, height = 5, units = 'cm', dpi = 2400)

Let’s try a linear model to check which combination is best:

# Full factorial linear model of arcsine-transformed validation accuracy,
# followed by diagnostic plots.
# NOTE(review): the classic variance-stabilizing transform for proportions
# is asin(sqrt(p)), not asin(p) -- confirm the plain asin() is intentional.
full_model = lm(asin(valid_acc)~arch*trans*pretrained*augmentation*label_smoothing*bp_valid, data = df)
plot(full_model)

# Forward stepwise selection (AIC) from the intercept-only model up to the
# full factorial scope above.
reduced_model = step(lm(asin(valid_acc)~1, data = df), 
                     scope = list(lower = formula(asin(valid_acc)~1), 
                                  upper = formula(asin(valid_acc)~arch*trans*pretrained*augmentation*label_smoothing*bp_valid)
                                  ),
                     direction = 'forward')

The best model is quite complex with some interactions

reduced_model
## 
## Call:
## lm(formula = asin(valid_acc) ~ bp_valid + pretrained + arch + 
##     augmentation + label_smoothing + trans + pretrained:arch + 
##     bp_valid:pretrained + bp_valid:arch + arch:augmentation + 
##     bp_valid:augmentation + pretrained:augmentation + augmentation:label_smoothing + 
##     pretrained:label_smoothing + arch:trans + augmentation:trans + 
##     bp_valid:pretrained:arch + pretrained:arch:augmentation + 
##     pretrained:augmentation:label_smoothing + arch:augmentation:trans + 
##     bp_valid:pretrained:augmentation, data = df)
## 
## Coefficients:
##                                               (Intercept)  
##                                                 6.561e-01  
##                                                bp_valid.L  
##                                                 3.114e-01  
##                                                bp_valid.Q  
##                                                -1.833e-01  
##                                                bp_valid.C  
##                                                 5.212e-02  
##                                                bp_valid^4  
##                                                 1.349e-02  
##                                                bp_valid^5  
##                                                -2.865e-02  
##                                                bp_valid^6  
##                                                 1.188e-02  
##                                                bp_valid^7  
##                                                 9.771e-03  
##                                                bp_valid^8  
##                                                -5.580e-03  
##                                            pretrainedTRUE  
##                                                 6.445e-02  
##                                   archig_resnext101_32x8d  
##                                                 4.786e-01  
##                                            archresnet101d  
##                                                 3.779e-01  
##                                             archresnet18d  
##                                                 3.050e-01  
##                                              archresnet50  
##                                                 4.784e-01  
##                                             archresnet50d  
##                                                 3.871e-01  
##                                       archwide_resnet50_2  
##                                                 5.112e-01  
##                                         augmentationMixUp  
##                                                 6.400e-02  
##                                        augmentationCutMix  
##                                                -2.131e-02  
##                                       label_smoothingTRUE  
##                                                 3.604e-02  
##                                                 transTRUE  
##                                                 1.098e-02  
##                    pretrainedTRUE:archig_resnext101_32x8d  
##                                                -3.675e-01  
##                             pretrainedTRUE:archresnet101d  
##                                                -3.555e-01  
##                              pretrainedTRUE:archresnet18d  
##                                                -2.542e-01  
##                               pretrainedTRUE:archresnet50  
##                                                -3.452e-01  
##                              pretrainedTRUE:archresnet50d  
##                                                -3.491e-01  
##                        pretrainedTRUE:archwide_resnet50_2  
##                                                -4.532e-01  
##                                 bp_valid.L:pretrainedTRUE  
##                                                 2.358e-01  
##                                 bp_valid.Q:pretrainedTRUE  
##                                                -4.133e-02  
##                                 bp_valid.C:pretrainedTRUE  
##                                                -5.603e-02  
##                                 bp_valid^4:pretrainedTRUE  
##                                                 4.519e-02  
##                                 bp_valid^5:pretrainedTRUE  
##                                                -1.260e-02  
##                                 bp_valid^6:pretrainedTRUE  
##                                                -1.502e-02  
##                                 bp_valid^7:pretrainedTRUE  
##                                                 5.130e-03  
##                                 bp_valid^8:pretrainedTRUE  
##                                                 3.372e-03  
##                        bp_valid.L:archig_resnext101_32x8d  
##                                                -2.926e-02  
##                        bp_valid.Q:archig_resnext101_32x8d  
##                                                -8.490e-03  
##                        bp_valid.C:archig_resnext101_32x8d  
##                                                 1.977e-02  
##                        bp_valid^4:archig_resnext101_32x8d  
##                                                -8.721e-03  
##                        bp_valid^5:archig_resnext101_32x8d  
##                                                -2.483e-03  
##                        bp_valid^6:archig_resnext101_32x8d  
##                                                 8.711e-03  
##                        bp_valid^7:archig_resnext101_32x8d  
##                                                -1.260e-02  
##                        bp_valid^8:archig_resnext101_32x8d  
##                                                 2.109e-04  
##                                 bp_valid.L:archresnet101d  
##                                                -1.653e-01  
##                                 bp_valid.Q:archresnet101d  
##                                                 7.381e-02  
##                                 bp_valid.C:archresnet101d  
##                                                 9.417e-04  
##                                 bp_valid^4:archresnet101d  
##                                                -2.948e-02  
##                                 bp_valid^5:archresnet101d  
##                                                 2.877e-02  
##                                 bp_valid^6:archresnet101d  
##                                                -1.376e-02  
##                                 bp_valid^7:archresnet101d  
##                                                -5.532e-03  
##                                 bp_valid^8:archresnet101d  
##                                                -4.737e-03  
##                                  bp_valid.L:archresnet18d  
##                                                -1.193e-01  
##                                  bp_valid.Q:archresnet18d  
##                                                 5.106e-02  
##                                  bp_valid.C:archresnet18d  
##                                                 1.054e-02  
##                                  bp_valid^4:archresnet18d  
##                                                -3.777e-02  
##                                  bp_valid^5:archresnet18d  
##                                                 2.669e-02  
##                                  bp_valid^6:archresnet18d  
##                                                -5.637e-03  
##                                  bp_valid^7:archresnet18d  
##                                                -1.105e-02  
##                                  bp_valid^8:archresnet18d  
##                                                 1.530e-03  
##                                   bp_valid.L:archresnet50  
##                                                -4.650e-02  
##                                   bp_valid.Q:archresnet50  
##                                                -6.761e-03  
##                                   bp_valid.C:archresnet50  
##                                                 2.939e-02  
##                                   bp_valid^4:archresnet50  
##                                                -1.920e-02  
##                                   bp_valid^5:archresnet50  
##                                                 2.184e-03  
##                                   bp_valid^6:archresnet50  
##                                                 1.335e-02  
##                                   bp_valid^7:archresnet50  
##                                                -1.706e-02  
##                                   bp_valid^8:archresnet50  
##                                                 2.617e-03  
##                                  bp_valid.L:archresnet50d  
##                                                -1.553e-01  
##                                  bp_valid.Q:archresnet50d  
##                                                 7.842e-02  
##                                  bp_valid.C:archresnet50d  
##                                                -4.425e-03  
##                                  bp_valid^4:archresnet50d  
##                                                -2.616e-02  
##                                  bp_valid^5:archresnet50d  
##                                                 2.930e-02  
##                                  bp_valid^6:archresnet50d  
##                                                -1.431e-02  
##                                  bp_valid^7:archresnet50d  
##                                                -3.372e-03  
##                                  bp_valid^8:archresnet50d  
##                                                -5.996e-03  
##                            bp_valid.L:archwide_resnet50_2  
##                                                -3.658e-02  
##                            bp_valid.Q:archwide_resnet50_2  
##                                                 1.171e-03  
##                            bp_valid.C:archwide_resnet50_2  
##                                                 1.877e-02  
##                            bp_valid^4:archwide_resnet50_2  
##                                                -1.304e-02  
##                            bp_valid^5:archwide_resnet50_2  
##                                                -1.398e-02  
##                            bp_valid^6:archwide_resnet50_2  
##                                                 2.189e-02  
##                            bp_valid^7:archwide_resnet50_2  
##                                                -1.898e-02  
##                            bp_valid^8:archwide_resnet50_2  
##                                                -9.812e-04  
##                 archig_resnext101_32x8d:augmentationMixUp  
##                                                -1.766e-02  
##                          archresnet101d:augmentationMixUp  
##                                                 5.130e-02  
##                           archresnet18d:augmentationMixUp  
##                                                -9.741e-03  
##                            archresnet50:augmentationMixUp  
##                                                -2.336e-02  
##                           archresnet50d:augmentationMixUp  
##                                                 2.380e-02  
##                     archwide_resnet50_2:augmentationMixUp  
##                                                -5.815e-02  
##                archig_resnext101_32x8d:augmentationCutMix  
##                                                 9.295e-02  
##                         archresnet101d:augmentationCutMix  
##                                                 1.288e-01  
##                          archresnet18d:augmentationCutMix  
##                                                 6.007e-02  
##                           archresnet50:augmentationCutMix  
##                                                 7.041e-02  
##                          archresnet50d:augmentationCutMix  
##                                                 1.017e-01  
##                    archwide_resnet50_2:augmentationCutMix  
##                                                 3.642e-02  
##                              bp_valid.L:augmentationMixUp  
##                                                 3.908e-03  
##                              bp_valid.Q:augmentationMixUp  
##                                                -6.958e-03  
##                              bp_valid.C:augmentationMixUp  
##                                                 4.087e-03  
##                              bp_valid^4:augmentationMixUp  
##                                                 2.047e-04  
##                              bp_valid^5:augmentationMixUp  
##                                                 5.062e-03  
##                              bp_valid^6:augmentationMixUp  
##                                                -4.185e-03  
##                              bp_valid^7:augmentationMixUp  
##                                                 2.057e-03  
##                              bp_valid^8:augmentationMixUp  
##                                                 4.019e-03  
##                             bp_valid.L:augmentationCutMix  
##                                                -1.606e-02  
##                             bp_valid.Q:augmentationCutMix  
##                                                 4.800e-03  
##                             bp_valid.C:augmentationCutMix  
##                                                 7.649e-05  
##                             bp_valid^4:augmentationCutMix  
##                                                -1.092e-04  
##                             bp_valid^5:augmentationCutMix  
##                                                 1.038e-02  
##                             bp_valid^6:augmentationCutMix  
##                                                -7.533e-03  
##                             bp_valid^7:augmentationCutMix  
##                                                 2.572e-03  
##                             bp_valid^8:augmentationCutMix  
##                                                 4.894e-03  
##                          pretrainedTRUE:augmentationMixUp  
##                                                -5.293e-02  
##                         pretrainedTRUE:augmentationCutMix  
##                                                 2.794e-02  
##                     augmentationMixUp:label_smoothingTRUE  
##                                                -5.032e-02  
##                    augmentationCutMix:label_smoothingTRUE  
##                                                -5.159e-02  
##                        pretrainedTRUE:label_smoothingTRUE  
##                                                -5.352e-02  
##                         archig_resnext101_32x8d:transTRUE  
##                                                 1.860e-03  
##                                  archresnet101d:transTRUE  
##                                                 1.813e-03  
##                                   archresnet18d:transTRUE  
##                                                 1.203e-03  
##                                    archresnet50:transTRUE  
##                                                -9.509e-03  
##                                   archresnet50d:transTRUE  
##                                                -4.479e-03  
##                             archwide_resnet50_2:transTRUE  
##                                                -1.943e-02  
##                               augmentationMixUp:transTRUE  
##                                                -1.450e-02  
##                              augmentationCutMix:transTRUE  
##                                                -5.982e-03  
##         bp_valid.L:pretrainedTRUE:archig_resnext101_32x8d  
##                                                -2.637e-02  
##         bp_valid.Q:pretrainedTRUE:archig_resnext101_32x8d  
##                                                -2.021e-02  
##         bp_valid.C:pretrainedTRUE:archig_resnext101_32x8d  
##                                                 1.837e-02  
##         bp_valid^4:pretrainedTRUE:archig_resnext101_32x8d  
##                                                -1.841e-02  
##         bp_valid^5:pretrainedTRUE:archig_resnext101_32x8d  
##                                                 2.453e-02  
##         bp_valid^6:pretrainedTRUE:archig_resnext101_32x8d  
##                                                -9.192e-03  
##         bp_valid^7:pretrainedTRUE:archig_resnext101_32x8d  
##                                                -3.420e-03  
##         bp_valid^8:pretrainedTRUE:archig_resnext101_32x8d  
##                                                -1.808e-02  
##                  bp_valid.L:pretrainedTRUE:archresnet101d  
##                                                 1.349e-01  
##                  bp_valid.Q:pretrainedTRUE:archresnet101d  
##                                                -7.756e-02  
##                  bp_valid.C:pretrainedTRUE:archresnet101d  
##                                                 3.039e-02  
##                  bp_valid^4:pretrainedTRUE:archresnet101d  
##                                                 6.843e-03  
##                  bp_valid^5:pretrainedTRUE:archresnet101d  
##                                                -1.434e-02  
##                  bp_valid^6:pretrainedTRUE:archresnet101d  
##                                                -1.204e-03  
##                  bp_valid^7:pretrainedTRUE:archresnet101d  
##                                                -6.390e-03  
##                  bp_valid^8:pretrainedTRUE:archresnet101d  
##                                                 1.045e-02  
##                   bp_valid.L:pretrainedTRUE:archresnet18d  
##                                                 1.673e-01  
##                   bp_valid.Q:pretrainedTRUE:archresnet18d  
##                                                -7.756e-02  
##                   bp_valid.C:pretrainedTRUE:archresnet18d  
##                                                 2.741e-02  
##                   bp_valid^4:pretrainedTRUE:archresnet18d  
##                                                 7.421e-04  
##                   bp_valid^5:pretrainedTRUE:archresnet18d  
##                                                -6.923e-04  
##                   bp_valid^6:pretrainedTRUE:archresnet18d  
##                                                -1.453e-02  
##                   bp_valid^7:pretrainedTRUE:archresnet18d  
##                                                 8.352e-03  
##                   bp_valid^8:pretrainedTRUE:archresnet18d  
##                                                -4.905e-03  
##                    bp_valid.L:pretrainedTRUE:archresnet50  
##                                                -1.594e-02  
##                    bp_valid.Q:pretrainedTRUE:archresnet50  
##                                                 9.914e-03  
##                    bp_valid.C:pretrainedTRUE:archresnet50  
##                                                -1.984e-02  
##                    bp_valid^4:pretrainedTRUE:archresnet50  
##                                                -1.313e-02  
##                    bp_valid^5:pretrainedTRUE:archresnet50  
##                                                 3.741e-02  
##                    bp_valid^6:pretrainedTRUE:archresnet50  
##                                                 5.299e-04  
##                    bp_valid^7:pretrainedTRUE:archresnet50  
##                                                -3.754e-03  
##                    bp_valid^8:pretrainedTRUE:archresnet50  
##                                                -1.681e-03  
##                   bp_valid.L:pretrainedTRUE:archresnet50d  
##                                                 1.205e-01  
##                   bp_valid.Q:pretrainedTRUE:archresnet50d  
##                                                -1.114e-01  
##                   bp_valid.C:pretrainedTRUE:archresnet50d  
##                                                 4.579e-02  
##                   bp_valid^4:pretrainedTRUE:archresnet50d  
##                                                 9.102e-03  
##                   bp_valid^5:pretrainedTRUE:archresnet50d  
##                                                -3.293e-02  
##                   bp_valid^6:pretrainedTRUE:archresnet50d  
##                                                 9.293e-03  
##                   bp_valid^7:pretrainedTRUE:archresnet50d  
##                                                 2.861e-02  
##                   bp_valid^8:pretrainedTRUE:archresnet50d  
##                                                 8.449e-03  
##             bp_valid.L:pretrainedTRUE:archwide_resnet50_2  
##                                                -2.731e-02  
##             bp_valid.Q:pretrainedTRUE:archwide_resnet50_2  
##                                                 2.152e-02  
##             bp_valid.C:pretrainedTRUE:archwide_resnet50_2  
##                                                -1.329e-02  
##             bp_valid^4:pretrainedTRUE:archwide_resnet50_2  
##                                                -2.174e-02  
##             bp_valid^5:pretrainedTRUE:archwide_resnet50_2  
##                                                 5.888e-02  
##             bp_valid^6:pretrainedTRUE:archwide_resnet50_2  
##                                                -4.246e-02  
##             bp_valid^7:pretrainedTRUE:archwide_resnet50_2  
##                                                 3.718e-02  
##             bp_valid^8:pretrainedTRUE:archwide_resnet50_2  
##                                                -3.049e-02  
##  pretrainedTRUE:archig_resnext101_32x8d:augmentationMixUp  
##                                                 6.789e-02  
##           pretrainedTRUE:archresnet101d:augmentationMixUp  
##                                                -3.619e-02  
##            pretrainedTRUE:archresnet18d:augmentationMixUp  
##                                                -3.443e-03  
##             pretrainedTRUE:archresnet50:augmentationMixUp  
##                                                 2.548e-02  
##            pretrainedTRUE:archresnet50d:augmentationMixUp  
##                                                 5.615e-04  
##      pretrainedTRUE:archwide_resnet50_2:augmentationMixUp  
##                                                 4.066e-02  
## pretrainedTRUE:archig_resnext101_32x8d:augmentationCutMix  
##                                                -9.468e-02  
##          pretrainedTRUE:archresnet101d:augmentationCutMix  
##                                                -1.122e-01  
##           pretrainedTRUE:archresnet18d:augmentationCutMix  
##                                                -8.547e-02  
##            pretrainedTRUE:archresnet50:augmentationCutMix  
##                                                -8.759e-02  
##           pretrainedTRUE:archresnet50d:augmentationCutMix  
##                                                -6.982e-02  
##     pretrainedTRUE:archwide_resnet50_2:augmentationCutMix  
##                                                -6.689e-02  
##      pretrainedTRUE:augmentationMixUp:label_smoothingTRUE  
##                                                 5.111e-02  
##     pretrainedTRUE:augmentationCutMix:label_smoothingTRUE  
##                                                 5.191e-02  
##       archig_resnext101_32x8d:augmentationMixUp:transTRUE  
##                                                -5.656e-03  
##                archresnet101d:augmentationMixUp:transTRUE  
##                                                 2.068e-02  
##                 archresnet18d:augmentationMixUp:transTRUE  
##                                                 1.244e-02  
##                  archresnet50:augmentationMixUp:transTRUE  
##                                                 9.801e-03  
##                 archresnet50d:augmentationMixUp:transTRUE  
##                                                -9.574e-04  
##           archwide_resnet50_2:augmentationMixUp:transTRUE  
##                                                 1.794e-02  
##      archig_resnext101_32x8d:augmentationCutMix:transTRUE  
##                                                 8.469e-03  
##               archresnet101d:augmentationCutMix:transTRUE  
##                                                -8.739e-03  
##                archresnet18d:augmentationCutMix:transTRUE  
##                                                -2.055e-02  
##                 archresnet50:augmentationCutMix:transTRUE  
##                                                 1.405e-02  
##                archresnet50d:augmentationCutMix:transTRUE  
##                                                -1.698e-04  
##          archwide_resnet50_2:augmentationCutMix:transTRUE  
##                                                 2.349e-03  
##               bp_valid.L:pretrainedTRUE:augmentationMixUp  
##                                                -5.862e-03  
##               bp_valid.Q:pretrainedTRUE:augmentationMixUp  
##                                                -9.938e-03  
##               bp_valid.C:pretrainedTRUE:augmentationMixUp  
##                                                 1.798e-02  
##               bp_valid^4:pretrainedTRUE:augmentationMixUp  
##                                                 5.561e-04  
##               bp_valid^5:pretrainedTRUE:augmentationMixUp  
##                                                -7.104e-03  
##               bp_valid^6:pretrainedTRUE:augmentationMixUp  
##                                                 2.117e-03  
##               bp_valid^7:pretrainedTRUE:augmentationMixUp  
##                                                -4.217e-03  
##               bp_valid^8:pretrainedTRUE:augmentationMixUp  
##                                                 1.170e-03  
##              bp_valid.L:pretrainedTRUE:augmentationCutMix  
##                                                -4.333e-02  
##              bp_valid.Q:pretrainedTRUE:augmentationCutMix  
##                                                 1.675e-03  
##              bp_valid.C:pretrainedTRUE:augmentationCutMix  
##                                                 2.397e-02  
##              bp_valid^4:pretrainedTRUE:augmentationCutMix  
##                                                -7.490e-03  
##              bp_valid^5:pretrainedTRUE:augmentationCutMix  
##                                                -1.199e-02  
##              bp_valid^6:pretrainedTRUE:augmentationCutMix  
##                                                 1.557e-02  
##              bp_valid^7:pretrainedTRUE:augmentationCutMix  
##                                                -2.549e-03  
##              bp_valid^8:pretrainedTRUE:augmentationCutMix  
##                                                -3.646e-03

Let’s now look at model predictions to get a better sense. We can see a few things:

# Build the grid of all distinct training-parameter combinations seen in the
# data, predict accuracy for each with the reduced model, and display the
# ranked combinations separately for every amount of validation data.
predictions = df %>%
  select(trans, arch, pretrained, label_smoothing, augmentation, bp_valid) %>%
  distinct()

# Back-transform predictions with sin() -- assumes the response was
# asin-transformed when reduced_model was fit (TODO confirm upstream).
predictions = predictions %>%
  mutate(predicted_acc = sin(predict(reduced_model, predictions))) %>%
  arrange(desc(predicted_acc))

# One ranked table per bp_valid level.
split(predictions, predictions$bp_valid)
## $`0.5`
## # A tibble: 168 × 7
##    trans arch     pretrained label_smoothing augmentation bp_valid predicted_acc
##    <lgl> <chr>    <lgl>      <lgl>           <fct>        <ord>            <dbl>
##  1 TRUE  resnet1… FALSE      FALSE           MixUp        0.5              0.840
##  2 FALSE resnet1… FALSE      FALSE           CutMix       0.5              0.834
##  3 TRUE  resnet1… FALSE      FALSE           CutMix       0.5              0.833
##  4 TRUE  resnet1… FALSE      TRUE            MixUp        0.5              0.832
##  5 FALSE resnet1… FALSE      FALSE           MixUp        0.5              0.829
##  6 FALSE resnet1… FALSE      TRUE            CutMix       0.5              0.826
##  7 TRUE  resnet5… FALSE      FALSE           CutMix       0.5              0.825
##  8 FALSE resnet5… FALSE      FALSE           CutMix       0.5              0.825
##  9 TRUE  resnet1… FALSE      TRUE            CutMix       0.5              0.825
## 10 FALSE resnet1… FALSE      TRUE            MixUp        0.5              0.821
## # … with 158 more rows
## 
## $`1`
## # A tibble: 168 × 7
##    trans arch     pretrained label_smoothing augmentation bp_valid predicted_acc
##    <lgl> <chr>    <lgl>      <lgl>           <fct>        <ord>            <dbl>
##  1 TRUE  resnet1… FALSE      FALSE           MixUp        1                0.901
##  2 FALSE resnet1… FALSE      FALSE           CutMix       1                0.895
##  3 TRUE  resnet1… FALSE      TRUE            MixUp        1                0.894
##  4 TRUE  resnet1… FALSE      FALSE           CutMix       1                0.894
##  5 FALSE resnet1… FALSE      FALSE           MixUp        1                0.892
##  6 TRUE  ig_resn… FALSE      FALSE           CutMix       1                0.888
##  7 FALSE resnet1… FALSE      TRUE            CutMix       1                0.887
##  8 TRUE  resnet1… FALSE      TRUE            CutMix       1                0.887
##  9 FALSE resnet1… FALSE      TRUE            MixUp        1                0.886
## 10 TRUE  resnet5… FALSE      FALSE           CutMix       1                0.884
## # … with 158 more rows
## 
## $`2`
## # A tibble: 168 × 7
##    trans arch     pretrained label_smoothing augmentation bp_valid predicted_acc
##    <lgl> <chr>    <lgl>      <lgl>           <fct>        <ord>            <dbl>
##  1 TRUE  ig_resn… FALSE      FALSE           CutMix       2                0.939
##  2 FALSE wide_re… FALSE      TRUE            None         2                0.938
##  3 TRUE  wide_re… FALSE      TRUE            None         2                0.935
##  4 TRUE  resnet50 FALSE      FALSE           CutMix       2                0.934
##  5 FALSE ig_resn… FALSE      FALSE           CutMix       2                0.934
##  6 TRUE  ig_resn… FALSE      TRUE            CutMix       2                0.934
##  7 FALSE resnet50 FALSE      FALSE           CutMix       2                0.931
##  8 FALSE wide_re… FALSE      FALSE           CutMix       2                0.930
##  9 TRUE  resnet50 FALSE      TRUE            CutMix       2                0.929
## 10 FALSE resnet50 FALSE      FALSE           MixUp        2                0.928
## # … with 158 more rows
## 
## $`5`
## # A tibble: 168 × 7
##    trans arch     pretrained label_smoothing augmentation bp_valid predicted_acc
##    <lgl> <chr>    <lgl>      <lgl>           <fct>        <ord>            <dbl>
##  1 TRUE  ig_resn… FALSE      FALSE           CutMix       5                0.958
##  2 FALSE wide_re… FALSE      TRUE            None         5                0.953
##  3 FALSE ig_resn… FALSE      FALSE           CutMix       5                0.953
##  4 TRUE  ig_resn… FALSE      TRUE            CutMix       5                0.953
##  5 TRUE  wide_re… FALSE      TRUE            None         5                0.951
##  6 FALSE ig_resn… FALSE      TRUE            CutMix       5                0.948
##  7 TRUE  resnet50 FALSE      FALSE           CutMix       5                0.948
##  8 TRUE  ig_resn… FALSE      TRUE            None         5                0.948
##  9 FALSE ig_resn… FALSE      FALSE           MixUp        5                0.947
## 10 FALSE wide_re… FALSE      FALSE           CutMix       5                0.945
## # … with 158 more rows
## 
## $`10`
## # A tibble: 168 × 7
##    trans arch     pretrained label_smoothing augmentation bp_valid predicted_acc
##    <lgl> <chr>    <lgl>      <lgl>           <fct>        <ord>            <dbl>
##  1 TRUE  ig_resn… FALSE      FALSE           CutMix       10               0.960
##  2 FALSE ig_resn… FALSE      FALSE           CutMix       10               0.956
##  3 TRUE  ig_resn… FALSE      TRUE            CutMix       10               0.956
##  4 FALSE ig_resn… FALSE      TRUE            CutMix       10               0.951
##  5 TRUE  resnet50 FALSE      FALSE           CutMix       10               0.950
##  6 FALSE wide_re… FALSE      TRUE            None         10               0.950
##  7 FALSE ig_resn… FALSE      FALSE           MixUp        10               0.949
##  8 TRUE  ig_resn… FALSE      TRUE            None         10               0.948
##  9 TRUE  wide_re… FALSE      TRUE            None         10               0.947
## 10 FALSE resnet50 FALSE      FALSE           CutMix       10               0.947
## # … with 158 more rows
## 
## $`20`
## # A tibble: 168 × 7
##    trans arch     pretrained label_smoothing augmentation bp_valid predicted_acc
##    <lgl> <chr>    <lgl>      <lgl>           <fct>        <ord>            <dbl>
##  1 TRUE  ig_resn… FALSE      FALSE           CutMix       20               0.961
##  2 FALSE ig_resn… FALSE      FALSE           CutMix       20               0.956
##  3 TRUE  ig_resn… FALSE      TRUE            CutMix       20               0.956
##  4 FALSE wide_re… FALSE      TRUE            None         20               0.955
##  5 TRUE  wide_re… FALSE      TRUE            None         20               0.952
##  6 FALSE ig_resn… FALSE      TRUE            CutMix       20               0.951
##  7 TRUE  resnet50 FALSE      FALSE           CutMix       20               0.951
##  8 TRUE  ig_resn… FALSE      TRUE            None         20               0.950
##  9 FALSE ig_resn… FALSE      FALSE           MixUp        20               0.950
## 10 FALSE resnet50 FALSE      FALSE           CutMix       20               0.948
## # … with 158 more rows
## 
## $`50`
## # A tibble: 168 × 7
##    trans arch     pretrained label_smoothing augmentation bp_valid predicted_acc
##    <lgl> <chr>    <lgl>      <lgl>           <fct>        <ord>            <dbl>
##  1 TRUE  ig_resn… FALSE      FALSE           CutMix       50               0.961
##  2 FALSE wide_re… FALSE      TRUE            None         50               0.957
##  3 FALSE ig_resn… FALSE      FALSE           CutMix       50               0.957
##  4 TRUE  ig_resn… FALSE      TRUE            CutMix       50               0.957
##  5 TRUE  wide_re… FALSE      TRUE            None         50               0.954
##  6 FALSE ig_resn… FALSE      TRUE            CutMix       50               0.952
##  7 TRUE  resnet50 FALSE      FALSE           CutMix       50               0.951
##  8 TRUE  ig_resn… FALSE      TRUE            None         50               0.951
##  9 FALSE ig_resn… FALSE      FALSE           MixUp        50               0.951
## 10 FALSE wide_re… FALSE      FALSE           CutMix       50               0.949
## # … with 158 more rows
## 
## $`100`
## # A tibble: 168 × 7
##    trans arch     pretrained label_smoothing augmentation bp_valid predicted_acc
##    <lgl> <chr>    <lgl>      <lgl>           <fct>        <ord>            <dbl>
##  1 TRUE  ig_resn… FALSE      FALSE           CutMix       100              0.960
##  2 FALSE wide_re… FALSE      TRUE            None         100              0.958
##  3 TRUE  wide_re… FALSE      TRUE            None         100              0.956
##  4 FALSE ig_resn… FALSE      FALSE           CutMix       100              0.956
##  5 TRUE  ig_resn… FALSE      TRUE            CutMix       100              0.956
##  6 TRUE  ig_resn… FALSE      TRUE            None         100              0.952
##  7 FALSE ig_resn… FALSE      TRUE            CutMix       100              0.951
##  8 FALSE ig_resn… FALSE      FALSE           MixUp        100              0.950
##  9 TRUE  resnet50 FALSE      FALSE           CutMix       100              0.950
## 10 FALSE wide_re… FALSE      FALSE           CutMix       100              0.949
## # … with 158 more rows
## 
## $`200`
## # A tibble: 168 × 7
##    trans arch     pretrained label_smoothing augmentation bp_valid predicted_acc
##    <lgl> <chr>    <lgl>      <lgl>           <fct>        <ord>            <dbl>
##  1 TRUE  ig_resn… FALSE      FALSE           CutMix       200              0.961
##  2 FALSE ig_resn… FALSE      FALSE           CutMix       200              0.957
##  3 TRUE  ig_resn… FALSE      TRUE            CutMix       200              0.957
##  4 FALSE wide_re… FALSE      TRUE            None         200              0.957
##  5 TRUE  wide_re… FALSE      TRUE            None         200              0.954
##  6 FALSE ig_resn… FALSE      TRUE            CutMix       200              0.952
##  7 TRUE  ig_resn… FALSE      TRUE            None         200              0.951
##  8 TRUE  resnet50 FALSE      FALSE           CutMix       200              0.951
##  9 FALSE ig_resn… FALSE      FALSE           MixUp        200              0.951
## 10 FALSE wide_re… FALSE      FALSE           CutMix       200              0.949
## # … with 158 more rows

Effect of sample quality

Now that we optimized training parameters, let’s evaluate the effect of sample quality. To do that, we did training using only 5 randomly chosen samples as training set, including 0-3 of the four lowest-quality samples per species. Quality was evaluated using two metrics: insert size or increase in T content throughout read length. We then evaluated, for each of the 5 samples per species left out of the training set, whether its prediction was correct.

We did 50 replicates randomly choosing the training set for each combination of quality metric and number of low-quality samples in the training set. Let’s now evaluate the results. Let’s start by reading the data.

# Read per-replicate prediction results for the sample-quality experiment.
# [-1] drops the unnamed first column (row indices saved with the csv).
df = read_csv('sample_quality.csv')[-1]
## New names:
## Rows: 93418 Columns: 13
## ── Column specification
## ──────────────────────────────────────────────────────── Delimiter: "," chr
## (6): bp_training, samples_training, qual_metric, sample_valid, valid_act... dbl
## (6): ...1, kmer_size, replicate, bp_valid, n_samp_training, n_lowqual_tr... lgl
## (1): valid_lowqual
## ℹ Use `spec()` to retrieve the full column specification for this data. ℹ
## Specify the column types or set `show_col_types = FALSE` to quiet this message.
## • `` -> `...1`
# A prediction is correct when the predicted label equals the true label.
df = df %>%
  mutate(correct_pred = valid_actual == valid_prediction)

df

It seems that in general including some low quality samples (by the variation in content metric) may improve high-quality samples a little bit, but only increases variation of low quality samples instead of clearly improving them.

# Distribution of per-replicate mean validation accuracy, faceted by the
# number of low-quality training samples (rows; quality judged by variation
# in GC content, metric 'high_c_sd') and by whether the validation samples
# themselves are low quality (columns).
p = df %>%
  filter(qual_metric == 'high_c_sd') %>%
  group_by(replicate, sample_valid, n_lowqual_training) %>%
  # Keep only the smallest amount of validation data per sample.
  # NOTE(review): the analogous insert-size plot uses max(bp_valid) --
  # confirm the min/max asymmetry is intentional.
  filter(bp_valid == min(bp_valid)) %>%
  group_by(replicate, n_lowqual_training, valid_lowqual) %>%
  summarize(mean_acc = mean(correct_pred)) %>%
  # Recode the logical flag into readable facet labels.
  mutate(valid_lowqual = c('TRUE' = 'Low quality samples as validation', 'FALSE' = 'High quality samples as validation')[as.character(valid_lowqual)]) %>%
  ggplot() +
  geom_histogram(aes(x = mean_acc), boundary = 1) +
  # Dummy secondary axis: used only to place a right-hand title for the
  # facet rows; it draws no breaks, labels, or guide.
  scale_y_continuous(sec.axis = sec_axis('identity', name = 'Number of low quality samples in training set',breaks = NULL, labels = NULL, guide = NULL)) + 
  scale_x_continuous(limits = c(0,1)) +
  xlab('Average validation accuracy across replicates') +
  ylab('Number of samples') +
  labs(title = 'Sequencing quality determined by variation in GC content') + 
  facet_grid(n_lowqual_training~valid_lowqual) +
  theme_few() +
  theme(strip.text.y = element_text(angle=0))
## `summarise()` has grouped output by 'replicate', 'n_lowqual_training'. You can
## override using the `.groups` argument.
p
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

# Save a copy for the manuscript figures.
ggsave(filename = 'quality_content.pdf',plot =p,device='pdf',path = 'paper_images',width = 7,height = 5,units = 'in')
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

The effect is less pronounced for average insert size.

# Same summary as the GC-content plot, but for the insert-size quality
# metric ('low_size').
p = df %>%
  filter(qual_metric == 'low_size') %>%
  group_by(replicate, sample_valid, n_lowqual_training) %>%
  # Keep only the largest amount of validation data per sample.
  # NOTE(review): the GC-content plot uses min(bp_valid) -- confirm the
  # min/max asymmetry is intentional.
  filter(bp_valid == max(bp_valid)) %>%
  group_by(replicate, n_lowqual_training, valid_lowqual) %>%
  summarize(mean_acc = mean(correct_pred)) %>%
  # Recode the logical flag into readable facet labels.
  mutate(valid_lowqual = c('TRUE' = 'Low quality samples as validation', 'FALSE' = 'High quality samples as validation')[as.character(valid_lowqual)]) %>%
  ggplot() +
  geom_histogram(aes(x = mean_acc), boundary = 1) +
  # Dummy secondary axis used only as a right-hand title for facet rows.
  scale_y_continuous(sec.axis = sec_axis('identity', name = 'Number of low quality samples in training set',breaks = NULL, labels = NULL, guide = NULL)) + 
  scale_x_continuous(limits = c(0,1)) +
  xlab('Average validation accuracy across replicates') +
  ylab('Number of samples') +
  labs(title = 'Sequencing quality determined by insert size') + 
  facet_grid(n_lowqual_training~valid_lowqual) +
  theme_few() +
  theme(strip.text.y = element_text(angle=0))
## `summarise()` has grouped output by 'replicate', 'n_lowqual_training'. You can
## override using the `.groups` argument.
p
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

# Save a copy for the manuscript figures.
ggsave(filename = 'quality_size.pdf',plot =p,device='pdf',path = 'paper_images',width = 7,height = 5,units = 'in')
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

If we order all samples by their validation accuracy and compare against the quality metrics, what do we see?

# Per-sample metadata and quality statistics (insert size, base-content SD,
# DNA concentration, collection info). [-1] drops the unnamed index column.
df_info = read_csv('sample_info_stats.csv')[-1]
## New names:
## Rows: 100 Columns: 11
## ── Column specification
## ──────────────────────────────────────────────────────── Delimiter: "," chr
## (7): species, collector, collection, country, dna_concentration, library... dbl
## (4): ...1, sample_number, insert_size, content_sd
## ℹ Use `spec()` to retrieve the full column specification for this data. ℹ
## Specify the column types or set `show_col_types = FALSE` to quiet this message.
## • `` -> `...1`
df_info

There seems to be a weak negative correlation between variation in content and accuracy, but many samples that seem to be good with this metric have always low accuracy.

# Mean validation accuracy per sample (largest amount of validation data
# only) plotted against the GC-content-variation quality metric, joined
# with the sample metadata for species coloring.
sample_acc = df %>%
  filter(qual_metric == 'high_c_sd') %>%
  group_by(replicate, sample_valid, n_lowqual_training) %>%
  filter(bp_valid == max(bp_valid)) %>%
  group_by(sample_valid, n_lowqual_training) %>%
  summarise(mean_acc = mean(correct_pred)) %>%
  left_join(df_info, by = c('sample_valid' = 'library_id'))

ggplot(sample_acc) +
  scale_x_sqrt() +
  geom_jitter(aes(x = content_sd, y = mean_acc, color = species),
              width = 0, height = 0.05) +
  scale_color_viridis_d(option = 'turbo') +
  facet_wrap(~n_lowqual_training)
## `summarise()` has grouped output by 'sample_valid'. You can override using the
## `.groups` argument.

Again, this is less pronounced for insert size

# Mean validation accuracy per sample (smallest amount of validation data
# only) plotted against the insert-size quality metric, joined with the
# sample metadata for species coloring.
sample_acc_size = df %>%
  filter(qual_metric == 'low_size') %>%
  group_by(replicate, sample_valid, n_lowqual_training) %>%
  filter(bp_valid == min(bp_valid)) %>%
  group_by(sample_valid, n_lowqual_training) %>%
  summarise(mean_acc = mean(correct_pred)) %>%
  left_join(df_info, by = c('sample_valid' = 'library_id'))

ggplot(sample_acc_size) +
  geom_jitter(aes(x = insert_size, y = mean_acc, color = species),
              width = 0, height = 0.05) +
  scale_color_viridis_d(option = 'turbo') +
  facet_wrap(~n_lowqual_training)
## `summarise()` has grouped output by 'sample_valid'. You can override using the
## `.groups` argument.

What is the relationship between DNA concentration and library quality?

First, let’s plot against standard deviation.

# DNA yield vs. standard deviation in base content. Yields recorded as
# 'too high' are plotted at 200 and zero yields at 0.05, so both extremes
# fit on the log-scaled x axis (labelled 'too low' / 'too high').
p1 = df_info %>%
  mutate(dna_c = ifelse(dna_concentration == 'too high', 200, dna_concentration)) %>%
  mutate(dna_c = as.numeric(dna_c)) %>%
  mutate(dna_c = ifelse(dna_c == 0, 0.05, dna_c)) %>%
  ggplot() +
  geom_point(aes(x = dna_c, y = content_sd)) +
  scale_y_log10(name = 'Standard deviation in base content') +
  scale_x_log10(name = 'DNA yield (ng/uL)',
                breaks = c(0.05, 0.1, 1, 10, 100, 200),
                labels = c('too\nlow', 0.1, 1, 10, 100, 'too\nhigh')) +
  theme_few()

p1

Now, against insert size

# DNA yield vs. insert size, with the same 'too high' / zero-yield
# recoding as the base-content panel above.
p2 = df_info %>%
  mutate(dna_c = ifelse(dna_concentration == 'too high', 200, dna_concentration)) %>%
  mutate(dna_c = as.numeric(dna_c)) %>%
  mutate(dna_c = ifelse(dna_c == 0, 0.05, dna_c)) %>%
  ggplot() +
  geom_point(aes(x = dna_c, y = insert_size)) +
  scale_y_continuous(name = 'Insert size (bp)') +
  scale_x_log10(name = 'DNA yield (ng/uL)',
                breaks = c(0.05, 0.1, 1, 10, 100, 200),
                labels = c('too\nlow', 0.1, 1, 10, 100, 'too\nhigh')) +
  theme_few()

p2

# Stack the two yield-vs-quality panels vertically and save for the paper.
p = cowplot::plot_grid(p1,p2,ncol = 1)

ggsave(filename = 'yield_vs_quality.pdf',plot =p,device='pdf',path = 'paper_images',width = 5,height = 8,units = 'in')

Bottom line: as long as the majority of the samples for each species are high-quality, having low-quality samples in the training set should not cause much trouble and might even improve inference for some low-quality samples.

Number of samples per species

Now let’s evaluate the effect of number of samples per species.

# Read per-replicate prediction results for the number-of-training-samples
# experiment. [-1] drops the unnamed first column (saved row indices).
df = read_csv('n_training.csv')[-1]
## New names:
## Rows: 164229 Columns: 10
## ── Column specification
## ──────────────────────────────────────────────────────── Delimiter: "," chr
## (5): bp_training, samples_training, sample_valid, valid_actual, valid_pr... dbl
## (5): ...1, kmer_size, replicate, bp_valid, n_samp_training
## ℹ Use `spec()` to retrieve the full column specification for this data. ℹ
## Specify the column types or set `show_col_types = FALSE` to quiet this message.
## • `` -> `...1`
# A prediction is correct when the predicted label equals the true label.
df = df %>%
  mutate(correct_pred = valid_actual == valid_prediction)

df

Does the number of samples used in training impact the validation accuracy? Let’s plot one panel for each sample. It seems it does.

# Validation accuracy as a function of training-set size: one boxplot per
# number of training samples, faceted by validation sample (and its species).
p = df %>%
  group_by(n_samp_training, bp_valid, sample_valid, valid_actual) %>%
  summarize(mean_acc = mean(correct_pred)) %>%
  ggplot() +
  #geom_jitter(aes(x = n_samp_training/10, y = mean_acc)) +
  # n_samp_training is divided by 10 to express training samples per
  # species -- presumably 10 units per sample; TODO confirm upstream.
  geom_boxplot(aes(x = n_samp_training/10, y = mean_acc, group = n_samp_training/10)) +
  facet_wrap(valid_actual~sample_valid) +
  theme_few()
## `summarise()` has grouped output by 'n_samp_training', 'bp_valid',
## 'sample_valid'. You can override using the `.groups` argument.
p

Let’s now plot only the average accuracy for each sample across replicates, with each sample represented by a line.

It seems that more samples in the training set do help, but in most cases about 4 samples is already pretty good. Let’s plot coloring by species.

# Average accuracy per validation sample as a line over training-set size,
# colored by species. The five-color palette is repeated for the ten
# species, so linetype (solid/dashed in blocks of five) disambiguates them.
species_acc = df %>%
  group_by(n_samp_training, sample_valid, valid_actual) %>%
  summarise(mean_acc = mean(correct_pred)) %>%
  mutate(valid_actual = fct_reorder(valid_actual, mean_acc))

ggplot(species_acc) +
  geom_line(aes(x = n_samp_training/10, y = mean_acc,
                group = sample_valid,
                color = valid_actual, linetype = valid_actual)) +
  scale_color_manual(values = rep(few_pal('Dark')(5), times = 2)) +
  scale_linetype_manual(values = rep(1:2, each = 5)) +
  theme_few()
## `summarise()` has grouped output by 'n_samp_training', 'sample_valid'. You can
## override using the `.groups` argument.

Now let’s try to use line type by sample quality instead.

# Join each sample's accuracy trajectory with a high/low DNA-yield flag
# derived from the recorded DNA concentration.
df_plot = df %>%
  group_by(n_samp_training, sample_valid, valid_actual) %>%
  summarize(mean_acc = mean(correct_pred)) %>%
  mutate(valid_actual = fct_reorder(valid_actual,mean_acc)) %>%
  left_join(df_info %>% 
              # Library ids follow the pattern 'S-<sample_number>'.
              mutate(sample_valid = paste0('S-',sample_number)) %>% 
              # 'too high' yields get a numeric stand-in so they sort high.
              # NOTE(review): 150 here vs 200 in the yield plots above --
              # harmless for a median split, but confirm it is deliberate.
              mutate(dna_concentration = ifelse(dna_concentration == 'too high',150,dna_concentration)) %>%
              mutate(dna_concentration = as.numeric(dna_concentration)) %>%
              # High quality = at or above the median DNA concentration.
              mutate(highqual = dna_concentration >= quantile(dna_concentration,probs=0.5)) %>%
              select(sample_valid, highqual))
## `summarise()` has grouped output by 'n_samp_training', 'sample_valid'. You can
## override using the `.groups` argument.
## Joining, by = "sample_valid"
# Quartiles and median of accuracy per training-set size.
# NOTE(review): df_ribbon appears unused -- the final plot computes its
# ribbon via stat_summary() instead. Candidate for removal.
df_ribbon = df_plot %>%
  group_by(n_samp_training) %>%
  summarise(q1 = quantile(mean_acc,0.25),
            median = median(mean_acc),
            q3 = quantile(mean_acc, 0.75))


# Final figure: one faint line per validation sample (linetype encodes DNA
# yield), the across-sample median in red, and the interquartile range as a
# pink ribbon computed on the fly by stat_summary().
p = ggplot(df_plot) +
  stat_summary(aes(x = n_samp_training/10, y = mean_acc),
               geom = 'ribbon', fill = 'pink',
               fun.min = function(x) quantile(x, 0.25),
               fun.max = function(x) quantile(x, 0.75)) +
  geom_line(aes(x = n_samp_training/10, y = mean_acc,
                group = sample_valid, linetype = highqual),
            alpha = 0.5, size = 0.25) +
  stat_summary(aes(x = n_samp_training/10, y = mean_acc),
               geom = 'line', color = 'red', size = 0.5, fun = 'median') +
  # "51" is a hex dash pattern (5 units on, 1 off) for low-yield samples.
  scale_linetype_manual(values = c('TRUE' = "solid", 'FALSE' = "51"),
                        name = 'DNA yield',
                        labels = c('TRUE' = 'High', 'FALSE' = 'Low')) +
  scale_x_continuous(breaks = 1:7) +
  ylab('Average validation accuracy') +
  xlab('Training samples per species') +
  theme_few(base_size = 6) +
  theme(legend.key.size = unit(0.2, "cm"))

p

# High-resolution PNG for the paper.
ggsave(filename = 'n_samples.png', plot = p, device = 'png',
       path = 'paper_images', width = 16, height = 5, units = 'cm',
       dpi = 2400)